extern crate alloc;

use burn::module::{Initializer, Param};
use burn::prelude::*;

use burn_store::{ModuleSnapshot, PytorchStore};
use std::path::Path;
use std::time::Instant;

// Backend selection: each Cargo feature below defines `MyBackend` as a different
// Burn backend. All four aliases share the same name, so enabling more than one of
// these features yields a duplicate definition, and enabling none of them leaves
// `MyBackend` undefined.

/// Backend alias used when the `wgpu` feature is enabled.
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;

/// Backend alias used when the `ndarray` feature is enabled (`f32` elements).
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;

/// Backend alias used when the `tch` feature is enabled (`f32` elements).
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;

/// Backend alias used when the `metal` feature is enabled.
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;

// Import model info generated by build.rs (includes the albert_model module)
include!(concat!(env!("OUT_DIR"), "/model_info.rs"));

// Use the albert_model module from model_info.rs
use albert_model::Model;

/// Container for the test inputs and reference outputs read from the PyTorch
/// test-data file.
///
/// `main` builds a zero-filled instance with [`TestData::new`] and then fills it
/// through `load_from` with a `PytorchStore`. The fields are wrapped in `Param`,
/// presumably so the derived `Module` exposes them to the store for loading
/// (NOTE(review): confirm against the burn-store documentation).
#[derive(Debug, Module)]
struct TestData<B: Backend> {
    /// Integer model input; created as `[1, 128]` zeros in `new`.
    input_ids: Param<Tensor<B, 2, Int>>,
    /// Integer model input; created as `[1, 128]` zeros in `new`.
    attention_mask: Param<Tensor<B, 2, Int>>,
    /// Integer model input; created as `[1, 128]` zeros in `new`.
    token_type_ids: Param<Tensor<B, 2, Int>>,
    /// Reference value for model output 0; created as `[1, 128, 768]` zeros in `new`.
    last_hidden_state: Param<Tensor<B, 3>>,
    /// Reference value for model output 1; created as `[1, 768]` zeros in `new`.
    pooler_output: Param<Tensor<B, 2>>,
}

impl<B: Backend> TestData<B> {
    fn new(device: &B::Device) -> Self {
        use burn::module::ParamId;
        // Initialize with correct shapes matching the test data
        // ALBERT base uses sequence_length=128, hidden_size=768
        // Note: Initializer only works for float tensors, Int tensors need manual init
        Self {
            input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
            attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
            token_type_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
            last_hidden_state: Initializer::Zeros.init([1, 128, 768], device),
            pooler_output: Initializer::Zeros.init([1, 768], device),
        }
    }
}

/// Returns the human-readable name shown in console output for `model_name`.
///
/// Identifiers without a dedicated display name are returned unchanged.
fn get_model_display_name(model_name: &str) -> &str {
    if model_name == "albert-base-v2" {
        "ALBERT Base v2"
    } else {
        model_name
    }
}

fn main() {
    // MODEL_NAME is set at build time from ALBERT_MODEL env var
    let model_name = MODEL_NAME;
    let display_name = get_model_display_name(model_name);

    println!("========================================");
    println!("{} Burn Model Test", display_name);
    println!("========================================\n");

    // Check if artifacts exist
    let artifacts_dir = Path::new("artifacts");
    if !artifacts_dir.exists() {
        eprintln!("Error: artifacts directory not found!");
        eprintln!("Please run get_model.py first to download the model and test data.");
        eprintln!("Example: uv run get_model.py --model {}", model_name);
        std::process::exit(1);
    }

    // Check if model files exist for this specific model
    let model_file = artifacts_dir.join(format!("{}_opset16.onnx", model_name));
    let test_data_file = artifacts_dir.join(format!("{}_test_data.pt", model_name));

    if !model_file.exists() || !test_data_file.exists() {
        eprintln!("Error: Model files not found for {}!", display_name);
        eprintln!("Please run: uv run get_model.py --model {}", model_name);
        eprintln!();
        eprintln!("Available models:");
        eprintln!("  - albert-base-v2");
        std::process::exit(1);
    }

    // Initialize the model (without weights for now)
    println!("Initializing {} model...", display_name);
    let start = Instant::now();
    let device = Default::default();
    let model: Model<MyBackend> = Model::default();
    let init_time = start.elapsed();
    println!("  Model initialized in {:.2?}", init_time);

    // Save model structure to file
    let model_txt_path = artifacts_dir.join(format!("{}_model.txt", model_name));
    println!(
        "\nSaving model structure to {}...",
        model_txt_path.display()
    );
    let model_str = format!("{}", model);
    std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file");
    println!("  Model structure saved");

    // Load test data from PyTorch file
    println!("\nLoading test data from {}...", test_data_file.display());
    let start = Instant::now();
    let mut test_data = TestData::<MyBackend>::new(&device);
    let mut store = PytorchStore::from_file(&test_data_file);
    test_data.load_from(&mut store).expect("Failed to load test data");
    let load_time = start.elapsed();
    println!("  Data loaded in {:.2?}", load_time);

    // Get the input tensors from test data
    let input_ids = test_data.input_ids.val();
    let attention_mask = test_data.attention_mask.val();
    let token_type_ids = test_data.token_type_ids.val();

    println!("  Loaded input tensors:");
    println!("    input_ids shape: {:?}", input_ids.shape().dims);
    println!(
        "    attention_mask shape: {:?}",
        attention_mask.shape().dims
    );
    println!(
        "    token_type_ids shape: {:?}",
        token_type_ids.shape().dims
    );

    // Get the reference outputs from test data
    let reference_last_hidden = test_data.last_hidden_state.val();
    let reference_pooler = test_data.pooler_output.val();
    println!("  Loaded reference outputs:");
    println!(
        "    last_hidden_state shape: {:?}",
        reference_last_hidden.shape().dims
    );
    println!(
        "    pooler_output shape: {:?}",
        reference_pooler.shape().dims
    );

    // Run inference with the loaded input
    println!("\nRunning model inference with test input...");
    let start = Instant::now();
    let outputs = model.forward(input_ids, attention_mask, token_type_ids);
    let inference_time = start.elapsed();
    println!("  Inference completed in {:.2?}", inference_time);

    // ALBERT models typically return (last_hidden_state, pooler_output)
    // The outputs tuple should have 2 elements
    println!("\n  Model output shapes:");
    println!(
        "    output 0 (last_hidden_state): {:?}",
        outputs.0.shape().dims
    );
    println!("    output 1 (pooler_output): {:?}", outputs.1.shape().dims);

    // Compare outputs
    println!("\nComparing model outputs with reference data...");

    // Compare last_hidden_state
    println!("  Checking last_hidden_state...");
    if outputs
        .0
        .clone()
        .all_close(reference_last_hidden.clone(), Some(1e-4), Some(1e-4))
    {
        println!("    ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
    } else {
        println!("    ⚠ last_hidden_state differs from reference data!");

        let diff = outputs.0.clone() - reference_last_hidden.clone();
        let abs_diff = diff.abs();
        let max_diff = abs_diff.clone().max().into_scalar();
        let mean_diff = abs_diff.mean().into_scalar();

        println!("      Maximum absolute difference: {:.6}", max_diff);
        println!("      Mean absolute difference: {:.6}", mean_diff);
    }

    // Compare pooler_output
    println!("  Checking pooler_output...");
    if outputs
        .1
        .clone()
        .all_close(reference_pooler.clone(), Some(1e-4), Some(1e-4))
    {
        println!("    ✓ pooler_output matches reference data within tolerance (1e-4)!");
    } else {
        println!("    ⚠ pooler_output differs from reference data!");

        let diff = outputs.1.clone() - reference_pooler.clone();
        let abs_diff = diff.abs();
        let max_diff = abs_diff.clone().max().into_scalar();
        let mean_diff = abs_diff.mean().into_scalar();

        println!("      Maximum absolute difference: {:.6}", max_diff);
        println!("      Mean absolute difference: {:.6}", mean_diff);
    }

    println!("\n========================================");
    println!("Model test completed!");
    println!("========================================");
}
